package util

import (
	"context"
	"errors"
	"github.com/gomodule/redigo/redis"
	"golang.org/x/sync/singleflight"
	"net/url"
	"runtime"
	"strings"
	"time"
)

// Redis wraps a redigo connection pool with convenience helpers for common
// commands.
//
// Fields:
//   - Pool: the redigo pool every command borrows a connection from.
//   - logger: optional callback installed by SetLogger; Do invokes it in a
//     new goroutine after each command.
//   - sfg: singleflight group used by the *OrSingleDo helpers so concurrent
//     cache misses for the same key share a single load.
//   - Dsn: the connection URL given to NewRedis.
//   - version: server version cached by Version.
//   - Options: dial options given to NewRedis, applied to every new connection.
type Redis struct {
	Pool    *redis.Pool
	logger  func(err error, cmd string, args ...any)
	sfg     *singleflight.Group
	Dsn     string
	version string
	Options []redis.DialOption
}

var (
	// ErrWrongRedisDsn is returned by NewRedis when the DSN's URL scheme is
	// not one it accepts.
	ErrWrongRedisDsn = errors.New("wrong redis dsn")
	// ErrWrongRedisExXx is returned by Set when its NX/XX argument is
	// neither empty, "NX" nor "XX".
	ErrWrongRedisExXx = errors.New("wrong redis ex xx")
)

// _redisRateScript is a GCRA (generic cell rate algorithm) rate limiter,
// adapted from:
// https://github.com/rwz/redis-gcra/blob/master/vendor/perform_gcra_ratelimit.lua
//
// KEYS[1]: key holding the limiter state (the "theoretical arrival time", TAT).
// ARGV[1]: rate, the number of requests allowed per period.
// ARGV[2]: period in seconds (RateLimit passes period.Seconds(), so it may be
// fractional).
//
// Each call costs one request: emission_interval = period/rate, and the burst
// allowance is a full period (emission_interval * rate). If the request is
// allowed, the script stores the new TAT with an EX of ceil(reset_after)
// seconds and returns 1. Otherwise it returns 0 without writing anything.
//
// The server TIME is rebased to 1483228800 (2017-01-01T00:00:00Z), presumably
// to keep the floating-point timestamps small.
// NOTE(review): redis.replicate_commands() presumably enables effects
// replication so that a write may follow the non-deterministic TIME call on
// older servers; confirm the minimum Redis version this must support.
var _redisRateScript = redis.NewScript(1, `
redis.replicate_commands()
local rate_limit_key = KEYS[1]
local rate = ARGV[1]
local period = ARGV[2]
local emission_interval = period / rate
local burst_offset = emission_interval * rate
local now = redis.call("TIME")
now = (now[1] - 1483228800) + (now[2] / 1000000)
local tat = redis.call("GET", rate_limit_key)
if not tat then
  tat = now
else
  tat = tonumber(tat)
end
tat = math.max(tat, now)
local new_tat = tat + emission_interval
local allow_at = new_tat - burst_offset
local diff = now - allow_at
local remaining = diff / emission_interval
if remaining < 0 then
  return 0
end
local reset_after = new_tat - now
if reset_after > 0 then
  redis.call("SET", rate_limit_key, new_tat, "EX", math.ceil(reset_after))
end
return 1
`)

// NewRedis builds a Redis client backed by a redigo connection pool.
//
// dsn must be a URL whose scheme is "redis" or "rediss" (TLS), for example
// "redis://:password@127.0.0.1:6379/0". A DSN that fails to parse returns the
// url.Parse error, and any other scheme returns ErrWrongRedisDsn. options are
// forwarded to redis.DialURL for every new connection.
//
// No connection is opened here. Pool sizing scales with the CPU count but
// never drops below 20 idle and 100 active connections. Because Wait is set,
// borrowers block when the pool is exhausted instead of failing.
func NewRedis(dsn string, options ...redis.DialOption) (*Redis, error) {
	u, err := url.Parse(dsn)
	if err != nil {
		return nil, err
	}
	// redis.DialURL accepts both schemes; "rediss" makes it dial over TLS.
	if u.Scheme != "redis" && u.Scheme != "rediss" {
		return nil, ErrWrongRedisDsn
	}

	cpuNum := runtime.NumCPU()
	maxIdle := cpuNum
	if maxIdle < 20 {
		maxIdle = 20
	}
	maxActive := cpuNum * 5
	if maxActive < 100 {
		maxActive = 100
	}

	dial := func() (redis.Conn, error) {
		return redis.DialURL(dsn, options...)
	}
	pool := &redis.Pool{
		Dial: dial,
		// redis.DialURL takes no context, so the most we can do is refuse to
		// start a dial for a context that is already done.
		DialContext: func(ctx context.Context) (redis.Conn, error) {
			if err := ctx.Err(); err != nil {
				return nil, err
			}
			return dial()
		},
		// Only PING connections that have been idle for 3+ minutes; recently
		// used connections are trusted as-is.
		TestOnBorrow: func(c redis.Conn, t time.Time) error {
			if time.Since(t) < 3*time.Minute {
				return nil
			}
			_, err := c.Do("PING")
			return err
		},
		MaxIdle:         maxIdle,
		MaxActive:       maxActive,
		IdleTimeout:     3 * time.Minute,
		Wait:            true,
		MaxConnLifetime: 30 * time.Minute,
	}
	return &Redis{
		Pool:    pool,
		Dsn:     dsn,
		Options: options,
		sfg:     new(singleflight.Group),
	}, nil
}

// SetLogger installs logFn as the per-command logger. Do calls it in a new
// goroutine after every command, passing the command's error (nil on
// success), the command name and its arguments. Passing nil disables logging.
func (r *Redis) SetLogger(logFn func(err error, cmd string, args ...any)) {
	r.logger = logFn
}

// Do is the low-level executor. It borrows a pooled connection, runs cmd
// with args, returns the connection to the pool and hands back the raw
// reply. If a logger is installed, it is invoked asynchronously with the
// outcome.
func (r *Redis) Do(cmd string, args ...any) (any, error) {
	c := r.Pool.Get()
	// Close returns the connection to the pool; its error is ignored.
	defer c.Close()

	reply, err := c.Do(cmd, args...)
	if logFn := r.logger; logFn != nil {
		go logFn(err, cmd, args...)
	}
	return reply, err
}

// DoContext runs cmd like Do but stops waiting once ctx is done.
//
// A command that has already been sent cannot be interrupted. If ctx ends
// first, DoContext returns (nil, ctx.Err()) right away, and the command
// finishes in the background with its reply discarded. If ctx is already
// done on entry, the command is not sent at all.
func (r *Redis) DoContext(ctx context.Context, cmd string, args ...any) (any, error) {
	if err := ctx.Err(); err != nil {
		return nil, err
	}

	type result struct {
		reply any
		err   error
	}
	// Buffered so the worker can always deliver its result and exit, even
	// after we stopped listening because ctx ended. The worker owns its own
	// locals, so there is no shared write with the caller (the previous
	// version raced on named results).
	done := make(chan result, 1)
	go func() {
		reply, err := r.Do(cmd, args...)
		done <- result{reply: reply, err: err}
	}()

	select {
	case <-ctx.Done():
		return nil, ctx.Err()
	case res := <-done:
		return res.reply, res.err
	}
}

// DoWithTimeout runs cmd through DoContext with a deadline of timeout from
// now. If the deadline passes first, the returned error is the context's
// error (context.DeadlineExceeded).
func (r *Redis) DoWithTimeout(timeout time.Duration, cmd string, args ...any) (any, error) {
	deadlineCtx, stop := context.WithTimeout(context.Background(), timeout)
	defer stop()
	reply, err := r.DoContext(deadlineCtx, cmd, args...)
	return reply, err
}

// Version returns the Redis server version, read from the "redis_version:"
// line of INFO server (e.g. "7.2.4").
//
// A non-empty result is cached on r and returned without a round trip on
// later calls. If the line is missing, Version returns "" with a nil error
// and caches nothing, so the next call queries again.
//
// NOTE(review): the cache field is read and written without synchronization,
// so concurrent first calls race on it. Confirm whether callers may invoke
// this concurrently.
func (r *Redis) Version() (string, error) {
	if r.version != "" {
		return r.version, nil
	}
	info, err := redis.String(r.Do("info", "server"))
	if err != nil {
		return "", err
	}
	const field = "redis_version:"
	for _, line := range strings.Split(info, "\n") {
		if strings.HasPrefix(line, field) {
			// TrimPrefix removes exactly the field name; TrimLeft would treat
			// it as a character set. TrimSpace drops the trailing "\r".
			r.version = strings.TrimSpace(strings.TrimPrefix(line, field))
			break
		}
	}
	return r.version, nil
}

// RateLimit reports whether one more request under key is allowed, given at
// most rate requests per period. It runs _redisRateScript (GCRA) on a pooled
// connection, passing period in seconds. This call goes straight to the pool,
// so it bypasses Do and the installed logger.
func (r *Redis) RateLimit(key string, rate uint, period time.Duration) (bool, error) {
	c := r.Pool.Get()
	// Close returns the connection to the pool; its error is ignored.
	defer c.Close()

	seconds := period.Seconds()
	reply, err := _redisRateScript.Do(c, key, rate, seconds)
	return redis.Bool(reply, err)
}

// Set stores val under key and reports whether the write happened.
//
// If period > 0, it sets a millisecond expiry (PX). Positive durations
// shorter than 1ms are rounded up to 1ms, because Redis rejects "PX 0". A
// period of zero or less means no expiry.
//
// NXorXX must be "", "NX" (only set if key is absent) or "XX" (only set if
// key exists). Any other value returns ErrWrongRedisExXx without contacting
// Redis.
//
// When an NX/XX condition prevents the write, Redis replies with nil, which
// redigo surfaces as redis.ErrNil. Set reports that case as (false, nil)
// instead of an error.
func (r *Redis) Set(key string, val any, period time.Duration, NXorXX string) (bool, error) {
	if NXorXX != "" && NXorXX != "NX" && NXorXX != "XX" {
		return false, ErrWrongRedisExXx
	}

	args := []any{key, val}
	if period > 0 {
		ms := period.Milliseconds()
		if ms < 1 {
			ms = 1
		}
		args = append(args, "PX", ms)
	}
	if NXorXX != "" {
		args = append(args, NXorXX)
	}

	status, err := redis.String(r.Do("SET", args...))
	if errors.Is(err, redis.ErrNil) {
		return false, nil
	}
	return status == "OK", err
}

// Incr adds step to the integer stored at key (INCRBY) and returns the
// resulting value.
func (r *Redis) Incr(key string, step int) (int64, error) {
	reply, err := r.Do("INCRBY", key, step)
	return redis.Int64(reply, err)
}

// Decr subtracts step from the integer stored at key (DECRBY) and returns
// the resulting value.
func (r *Redis) Decr(key string, step int) (int64, error) {
	reply, err := r.Do("DECRBY", key, step)
	return redis.Int64(reply, err)
}

// Del deletes keys and returns the count Redis reports for DEL.
//
// With no keys it returns (0, nil) without contacting Redis, because a bare
// DEL is a Redis "wrong number of arguments" error. This matters for calls
// such as r.Del(keys...) where keys happens to be empty.
func (r *Redis) Del(keys ...any) (int64, error) {
	if len(keys) == 0 {
		return 0, nil
	}
	return redis.Int64(r.Do("DEL", keys...))
}

// Exists reports whether key exists.
//
// Errors (connection failures, unexpected replies) are not surfaced; they
// are reported as false.
func (r *Redis) Exists(key string) bool {
	count, err := redis.Int(r.Do("EXISTS", key))
	if err != nil {
		return false
	}
	return count == 1
}

// Expire sets key to expire after period, using PEXPIRE with millisecond
// precision (sub-millisecond remainders are truncated). It returns Redis's
// reply converted to bool.
func (r *Redis) Expire(key string, period time.Duration) (bool, error) {
	ms := period.Milliseconds()
	reply, err := r.Do("PEXPIRE", key, ms)
	return redis.Bool(reply, err)
}

// ExpireAt sets key to expire at the Unix timestamp at (in seconds) via
// EXPIREAT. It returns Redis's reply converted to bool.
func (r *Redis) ExpireAt(key string, at int64) (bool, error) {
	reply, err := r.Do("EXPIREAT", key, at)
	return redis.Bool(reply, err)
}

// HMSet writes every field/value pair in obj into the hash at key (HMSET)
// and reports whether Redis replied OK.
func (r *Redis) HMSet(key string, obj map[string]any) (bool, error) {
	cmdArgs := make([]any, 0, 1+2*len(obj))
	cmdArgs = append(cmdArgs, key)
	for field, value := range obj {
		cmdArgs = append(cmdArgs, field, value)
	}
	status, err := redis.String(r.Do("HMSET", cmdArgs...))
	return status == "OK", err
}

// HSet sets one field k of the hash at key to v and returns Redis's integer
// reply.
func (r *Redis) HSet(key string, k string, v any) (int64, error) {
	reply, err := r.Do("HSET", key, k, v)
	return redis.Int64(reply, err)
}

// HGet returns field k of the hash at key as a string. Following redigo's
// conversion rules, a missing field yields redis.ErrNil.
func (r *Redis) HGet(key string, k string) (string, error) {
	reply, err := r.Do("HGET", key, k)
	return redis.String(reply, err)
}

// HMGet fetches the fields ks of the hash stored at key and returns them
// keyed by field name. Missing fields map to "", because redigo's Strings
// helper turns nil array items into empty strings.
//
// Field names must be string or []byte. Any other type is rejected with an
// error before a command is sent (previously this panicked after the round
// trip). With no fields, HMGet returns an empty map without contacting
// Redis, since HMGET without fields is a Redis arity error.
func (r *Redis) HMGet(key string, ks ...any) (map[string]string, error) {
	names := make([]string, len(ks))
	for i, k := range ks {
		switch v := k.(type) {
		case string:
			names[i] = v
		case []byte:
			names[i] = string(v)
		default:
			return nil, errors.New("hmget: field names must be string or []byte")
		}
	}

	re := make(map[string]string, len(names))
	if len(names) == 0 {
		return re, nil
	}

	args := make([]any, 0, len(ks)+1)
	args = append(args, key)
	args = append(args, ks...)
	ss, err := redis.Strings(r.Do("HMGET", args...))
	if err != nil {
		return nil, err
	}
	for i, name := range names {
		// Defensive: HMGET replies with one entry per requested field.
		if i < len(ss) {
			re[name] = ss[i]
		}
	}
	return re, nil
}

// HGetAll returns every field/value pair of the hash at key (HGETALL) as a
// map.
func (r *Redis) HGetAll(key string) (map[string]string, error) {
	reply, err := r.Do("HGETALL", key)
	return redis.StringMap(reply, err)
}

// HGetAllOrSingleDo returns the hash stored at key and falls back to fn on a
// cache miss. Concurrent callers for the same key share one load via
// singleflight, which protects the backing store from a thundering herd.
//
// A miss is an empty HGETALL reply. Redis answers HGETALL on a missing key
// with an empty array, never nil, so the previous nil check never fired and
// fn was never called. fn receives key and r and is responsible for any
// write-back. Its result must be convertible by redis.StringMap, i.e. a []any
// of alternating field/value entries (string or []byte). If fn is nil, this
// is plain HGetAll.
func (r *Redis) HGetAllOrSingleDo(key string, fn func(key string, redis *Redis) (any, error)) (map[string]string, error) {
	if fn == nil {
		return r.HGetAll(key)
	}
	re, err, _ := r.sfg.Do(key, func() (any, error) {
		vals, err := redis.Values(r.Do("HGETALL", key))
		if err != nil {
			return nil, err
		}
		if len(vals) > 0 {
			return vals, nil
		}
		return fn(key, r)
	})
	return redis.StringMap(re, err)
}

// Get returns the raw GET reply for key. A missing key yields a nil reply
// with a nil error.
func (r *Redis) Get(key string) (any, error) {
	reply, err := r.Do("GET", key)
	return reply, err
}

// GetOrSingleDo returns the raw value at key and falls back to fn when GET
// yields nil. Concurrent callers for the same key share one load through the
// singleflight group. fn receives key and r and is responsible for any
// write-back. If fn is nil, this is plain Get.
func (r *Redis) GetOrSingleDo(key string, fn func(key string, redis *Redis) (any, error)) (any, error) {
	if fn == nil {
		return r.Get(key)
	}
	load := func() (any, error) {
		cached, err := r.Get(key)
		switch {
		case err != nil:
			return nil, err
		case cached != nil:
			return cached, nil
		default:
			return fn(key, r)
		}
	}
	value, err, _ := r.sfg.Do(key, load)
	return value, err
}

// GetStr returns the value at key as a string. Following redigo's conversion
// rules, a missing key yields redis.ErrNil.
func (r *Redis) GetStr(key string) (string, error) {
	reply, err := r.Get(key)
	return redis.String(reply, err)
}

// GetInt returns the value at key parsed as an int64.
func (r *Redis) GetInt(key string) (int64, error) {
	reply, err := r.Get(key)
	return redis.Int64(reply, err)
}

// GetNum returns the value at key parsed as a float64.
func (r *Redis) GetNum(key string) (float64, error) {
	reply, err := r.Get(key)
	return redis.Float64(reply, err)
}

// GetBool returns the value at key converted by redigo's Bool helper.
func (r *Redis) GetBool(key string) (bool, error) {
	reply, err := r.Get(key)
	return redis.Bool(reply, err)
}
